In [2]:
using PyPlot, Interact
addprocs(12);   # worker processes for the parallel (M, P) sweeps below
push!(LOAD_PATH, "../src/")
using HDStat
Our model for high-dimensional static decoding of neural activities: $$ y^T = w^T \left(UX_0 + Z \right) + \epsilon^T $$
The generative model's parameters are $(N, K, P, r, s)$: $N$ neurons, a $K$-dimensional latent signal, $P$ samples, signal strength $r$, and output-noise level $s$. We write $$X = UX_0 + Z$$ to denote the neural activities.
We model the observations of neural activities as $$ \hat{X} = S\left(UX_0 + Z \right) $$ where $S$ is an $M$-by-$N$ random sampling matrix. We also measure the behavioral output $y$.
The observation model's only parameter is $M$.
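For reference, the dimensions implied by the code below: $U \in \mathbb{R}^{N \times K}$, $X_0 \in \mathbb{R}^{K \times P}$, $Z \in \mathbb{R}^{N \times P}$, $w \in \mathbb{R}^N$, $\epsilon, y \in \mathbb{R}^P$, $S \in \mathbb{R}^{M \times N}$, and hence $\hat{X} \in \mathbb{R}^{M \times P}$.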
In [3]:
@everywhere immutable GenModel
    N::Integer   # number of neurons
    K::Integer   # latent signal dimension
    P::Integer   # number of samples
    r::Number    # signal strength
    s::Number    # output noise level
    w::Array{Float64, 1}   # ground-truth decoding vector, drawn inside the signal subspace
    U::Array{Float64, 2}   # N-by-K mixing matrix
    function GenModel(N::Integer, K::Integer, P::Integer, r::Number, s::Number)
        U = randn(N, K) / sqrt(N)
        Up, _ = qr(U)        # orthonormal basis for the column space of U
        w = Up * randn(K)
        w /= norm(w)
        return new(N, K, P, r, s, w, U)
    end
end

@everywhere immutable ObsModel
    gen::GenModel
    M::Integer             # number of observed (sampled) neurons
    S::Array{Float64, 2}   # M-by-N sampling matrix
    function ObsModel(gen::GenModel, M::Integer)
        S = eye(gen.N)[randperm(gen.N)[1:M], :]   # M distinct rows of the N-by-N identity
        return new(gen, M, S)
    end
end
In [4]:
@everywhere function rand(model::GenModel)
    let N = model.N, K = model.K, P = model.P, r = model.r, s = model.s, w = model.w, U = model.U
        X0 = qr(randn(P, K))[1]' * r * sqrt(P / K)   # K-by-P latent signal; every singular value equals r * sqrt(P / K)
        Z = randn(N, P) / sqrt(N)                    # input noise
        ϵ = randn(P) * s                             # output noise
        X = U * X0 + Z                               # neural activities
        y = vec(w' * X + ϵ')                         # behavioral output
        return {:X0 => X0, :X => X, :Z => Z, :e => ϵ, :y => y}
    end
end

@everywhere function rand(model::ObsModel)
    rst = rand(model.gen)
    rst[:Xhat] = model.S * rst[:X]   # observe only M of the N neurons
    return rst
end
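A quick usage sketch (the parameter values here are illustrative only, not the ones used in the experiments below):

g = GenModel(1000, 5, 200, 0.25, 0.1)
o = ObsModel(g, 100)
sample = rand(o)
size(sample[:X])      # (1000, 200): full neural activities
size(sample[:Xhat])   # (100, 200): sampled neural activities
length(sample[:y])    # 200: behavioral output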
The $K$ inferred from the data is the number of singular values of $\hat{X}$ above the noise floor of the sampled input noise,
$$ \frac{\sqrt{P} + \sqrt{M}}{\sqrt{N}}. $$
The correct $K$ is inferred when the minimum singular value of the sampled signal exceeds the detection threshold $(MP)^{1/4}/\sqrt{N}$, i.e.,
$$ r\sqrt{\frac{P}{K}}\left(\sqrt{\frac{M}{N}} - \sqrt{\frac{K}{N}}\right) \geq \frac{\left(MP\right)^{1/4}}{\sqrt{N}} \quad\Rightarrow\quad r^2\frac{\sqrt{MP}}{K}\left(1 - \sqrt{\frac{K}{M}}\right)^2 \geq 1. $$
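The condition on the right is easy to encode directly; a minimal sketch (the helper detectable is hypothetical, not part of HDStat):

detectable(K, M, P, r) = r^2 * sqrt(M * P) / K * (1 - sqrt(K / M))^2 >= 1   # hypothetical helper mirroring the condition above
detectable(5, 400, 400, 0.25)   # true: rank recovery is predicted at these sizes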
In [5]:
@everywhere function inferK(model::ObsModel; Xhat = nothing)
    if Xhat == nothing; Xhat = rand(model)[:Xhat]; end
    let gen = model.gen
        _, S, _ = svd(Xhat)
        # count singular values above the noise floor (sqrt(P) + sqrt(M)) / sqrt(N)
        return sum(S .> (sqrt(gen.P) + sqrt(model.M)) / sqrt(gen.N))
    end
end
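For instance (a single hypothetical draw; the sweep below covers the full grid):

o = ObsModel(GenModel(5000, 5, 400, 0.25, 0.0), 400)
inferK(o)   # should return 5, since the detectability condition above holds at these sizes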
In [6]:
@everywhere K, N, r, s = 5, 5000, 0.25, 0.0
@everywhere Ms, Ps = 10:10:400, 10:10:400
rst = [@spawn inferK(ObsModel(GenModel(N, K, P, r, s), M)) for M in Ms, P in Ps]   # one trial per (M, P) grid point
rst = map(fetch, rst);
# theoretical detectability boundary from the condition above
theory = sqrt(Ms * Ps') / K * r^2 .* (1 - sqrt(K ./ repmat(Ms, 1, length(Ps)))).^2 .> 1;
In [7]:
K, N, r, s = 5, 5000, 0.1, 0.0
M, P = 50, 50
o = ObsModel(GenModel(N, K, P, r, s), M)
x = rand(o)
X0 = x[:X0]
SU = o.S * o.gen.U   # sampled mixing matrix, an M-by-K Gaussian block
SUX0 = SU * X0       # sampled signal
println(string("Signal svs:\n", svd(X0)[2]))
println(string("theory:\n", r * sqrt(P / K)))   # every singular value of X0 equals r * sqrt(P / K) by construction
# Bai-Yin: the extreme singular values of SU concentrate at (sqrt(M) ± sqrt(K)) / sqrt(N)
tmp = sqrt(eigs(SU * SU'; nev=K)[1])
println("min/max sv of sampling/projection:")
println((minimum(tmp), maximum(tmp)))
println(((sqrt(M) - sqrt(K)) / sqrt(N), (sqrt(M) + sqrt(K)) / sqrt(N)))
tmp = sqrt(eigs(SUX0 * SUX0'; nev=K)[1])
println("min/max sv of sampled/projected signal:")
println((minimum(tmp), maximum(tmp)))
println(((sqrt(M) - sqrt(K)) / sqrt(N) * r * sqrt(P / K), (sqrt(M) + sqrt(K)) / sqrt(N) * r * sqrt(P / K)))
println("Threshold:")
println((M * P)^(1/4) / sqrt(N))
In [8]:
figure(figsize=(4, 3))
imshow(rst, aspect="auto", interpolation="nearest", origin="lower", cmap="RdBu_r", extent=[minimum(Ps), maximum(Ps), minimum(Ms), maximum(Ms)]);
colorbar();
contour(repmat(Ps', length(Ms), 1), repmat(Ms, 1, length(Ps)), theory, 1, linewidths=4, colors="k")
Out[8]: [figure: empirically inferred K over the (M, P) grid, with the theoretical detectability boundary overlaid in black]
Problem: given $\hat{X}$ and $y$, find $\hat{w}$ such that $\|\hat{w}^T\hat{X} - y^T\|_2$ is minimized on a held-out validation dataset.
Analysis: what is the angle between the inferred $\hat{w}$ and the sampled ground truth $Sw$, i.e., the overlap $|\langle Sw, \hat{w}\rangle| / (\|Sw\|\,\|\hat{w}\|)$?
Algorithm #0: a cheating baseline with $\hat{w} = \alpha Sw$. In other words, we simply find the best scaling of the sampled ground-truth decoding vector.
Algorithm #1: simple linear regression of $y$ against $\hat{X}$.
Algorithm #2: first infer the sampled signal subspace $\hat{U}$ from $\hat{X}$ using low-rank perturbation theory, then regress $y$ against $\hat{U}^T\hat{X}$.
Algorithm #3: recover the best estimate $\tilde{X}$ of the sampled signal from $\hat{X}$ using Gavish and Donoho's optimal singular-value shrinkage under the Frobenius error metric (see the relation noted after this list), then regress $y$ against $\tilde{X}$.
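Algorithm #3 relies on a standard spiked-model relation: for an $M \times P$ observation with aspect ratio $\beta = M/P$ and noise entries of standard deviation $1/\sqrt{P}$, a signal singular value $x > \beta^{1/4}$ produces an observed singular value $$ y = \frac{\sqrt{(1 + x^2)(\beta + x^2)}}{x}, $$ which inverts to the estimate applied inside signal_w below: $$ x^2 = \frac{y^2 - \beta - 1 + \sqrt{(y^2 - \beta - 1)^2 - 4\beta}}{2}. $$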
In [9]:
# algorithm #0: rescale the sampled ground truth (requires knowing w, hence "cheating")
@everywhere function cheat_w(Xhat, y, model::ObsModel)
    Sw = model.S * model.gen.w
    alpha = sum(y .* y) / sum(y .* vec(Sw' * Xhat))   # scale chosen so that <y, alpha * Sw' * Xhat> = <y, y>
    return alpha * Sw
end

# algorithm #1: ordinary least squares of y on Xhat
@everywhere function simple_w(Xhat, y)
    return pinv(Xhat * Xhat') * (Xhat * y)
end

# algorithm #2: project onto the inferred signal subspace, then regress
@everywhere function subspace_w(Xhat, y, model::ObsModel)
    let gen = model.gen
        thresh = sqrt(gen.P / gen.N) + sqrt(model.M / gen.N)   # noise floor of the sampled input noise
        U, S, V = svd(Xhat)
        K = sum(S .> thresh)   # inferred signal rank
        if K < 1; return zeros(size(Xhat, 1)); end
        Xtilde = U[:, 1:K]' * Xhat
        return U[:, 1:K] * pinv(Xtilde * Xtilde') * (Xtilde * y)
        # return U[:, 1:K] * ((U[:, 1:K]' * Xhat)' \ y)
    end
end

# algorithm #3: Gavish-Donoho singular-value shrinkage, then regress
@everywhere function signal_w(Xhat, y, model::ObsModel)
    let gen = model.gen
        U, S, V = svd(Xhat)
        S = S * sqrt(gen.N / gen.P)          # rescale so the noise entries have std 1 / sqrt(P)
        thresh = 1 + sqrt(model.M / gen.P)   # bulk edge for aspect ratio beta = M / P
        beta = model.M / gen.P
        mask = S .> thresh
        if sum(mask) < 1; return zeros(size(Xhat, 1)); end
        # invert the spiked-model relation to estimate the signal singular values
        S[mask] = sqrt(S[mask].^2 - beta - 1 + sqrt((S[mask].^2 - beta - 1).^2 - 4 * beta)) / sqrt(2)
        S[~mask] = 0
        Xtilde = U * diagm(S * sqrt(gen.P / gen.N)) * V'   # undo the rescaling
        return pinv(Xtilde * Xtilde') * (Xtilde * y)
    end
end
In [10]:
f = figure(figsize=(18, 6))
println("Left four panels show fitted coefficients against Sw, in the order:")
println("#0, #1")
println("#2, #3")
@manipulate for N in [2000, 5000], M in [50:50:1000], K in [2:4:42], P in [50:50:1000], r in 0.0:0.1:0.5, s in 0:0.1:1
    g = GenModel(N, K, P, r, s)
    o = ObsModel(g, M)
    wS = o.S * g.w
    train = rand(o)
    # build a held-out test set of at least 10000 samples
    ytest = Float64[]
    Xtest = zeros(M, 0)
    while length(ytest) < 10000
        tmp = rand(o)
        ytest = [ytest; tmp[:y]]
        Xtest = [Xtest tmp[:Xhat]]
    end
    what_cheat = cheat_w(train[:Xhat], train[:y], o)
    what_simple = simple_w(train[:Xhat], train[:y])
    what_subspace = subspace_w(train[:Xhat], train[:y], o)
    what_signal = signal_w(train[:Xhat], train[:y], o)
    angle_cheat = abs(sum(wS .* what_cheat)) / norm(wS) / norm(what_cheat)
    angle_simple = abs(sum(wS .* what_simple)) / norm(wS) / norm(what_simple)
    angle_subspace = abs(sum(wS .* what_subspace)) / norm(wS) / norm(what_subspace)
    angle_signal = abs(sum(wS .* what_signal)) / norm(wS) / norm(what_signal)
    err_cheat = norm(vec(what_cheat' * Xtest) - ytest)^2 / norm(ytest)^2
    err_simple = norm(vec(what_simple' * Xtest) - ytest)^2 / norm(ytest)^2
    err_subspace = norm(vec(what_subspace' * Xtest) - ytest)^2 / norm(ytest)^2
    err_signal = norm(vec(what_signal' * Xtest) - ytest)^2 / norm(ytest)^2
    withfig(f) do
        subplot(2, 6, 1)
        plot(wS, what_cheat, ".")
        title(string("Inferred K: ", inferK(o; Xhat=train[:Xhat])))
        subplot(2, 6, 2)
        plot(wS, what_simple, ".")
        subplot(2, 6, 7)
        plot(wS, what_subspace, ".")
        subplot(2, 6, 8)
        plot(wS, what_signal, ".")
        subplot(132)
        bar(1:4, [angle_cheat, angle_simple, angle_subspace, angle_signal])
        xticks(1:4, ["#0", "#1", "#2", "#3"]); title("Overlap between Sw and w_hat")
        ylim([0, 1])
        subplot(133)
        bar(1:4, [err_cheat, err_simple, err_subspace, err_signal])
        xticks(1:4, ["#0", "#1", "#2", "#3"]); title("Normalized error on held out data")
        ylim([1e-2, 1e2]); yscale("log")
    end
end
Out[10]: [interactive figure: scatter of fitted coefficients against Sw for algorithms #0 through #3, plus overlap and held-out-error bar charts]
In [11]:
@everywhere overlap(v1, v2) = abs(sum(v1 .* v2)) / (norm(v1) + eps(Float64)) / (norm(v2) + eps(Float64))   # |cos| of the angle; eps guards against zero vectors
@everywhere testerror(w, X, y) = norm(vec(w' * X) - y)^2 / norm(y)^2   # normalized squared prediction error
@everywhere function trial(K, N, M, P, r, s)
    g = GenModel(N, K, P, r, s)
    o = ObsModel(g, M)
    wS = o.S * g.w
    train = rand(o)
    # held-out test set of at least 10000 samples
    ytest = Float64[]
    Xtest = zeros(M, 0)
    while length(ytest) < 10000
        tmp = rand(o)
        ytest = [ytest; tmp[:y]]
        Xtest = [Xtest tmp[:Xhat]]
    end
    what_cheat = cheat_w(train[:Xhat], train[:y], o)
    what_simple = simple_w(train[:Xhat], train[:y])
    what_subspace = subspace_w(train[:Xhat], train[:y], o)
    what_signal = signal_w(train[:Xhat], train[:y], o)
    # (overlap, test error) for each algorithm, in order #0 through #3
    return map(w -> (overlap(wS, w), testerror(w, Xtest, ytest)), {what_cheat, what_simple, what_subspace, what_signal})
end
In [12]:
@everywhere K, N, r, s = 5, 5000, 0.25, 0.0
@everywhere Ms, Ps = 10:10:400, 10:10:400
In [13]:
rst = [@spawn trial(K, N, M, P, r, s) for M = Ms, P = Ps];
rst = map(fetch, rst);
In [14]:
angle_cheat = map(x -> x[1][1], rst);
angle_simple = map(x -> x[2][1], rst);
angle_subspace = map(x -> x[3][1], rst);
angle_signal = map(x -> x[4][1], rst);
err_cheat = map(x -> x[1][2], rst);
err_simple = map(x -> x[2][2], rst);
err_subspace = map(x -> x[3][2], rst);
err_signal = map(x -> x[4][2], rst);
In [15]:
# detectability boundary, the same condition as `theory` in In [6]
valid = r^2 * sqrt(Ms * Ps') / K .* (1 - sqrt(K ./ repmat(Ms, 1, length(Ps)))).^2 .> 1;
In [16]:
function plot_helper(angle)
    extent = [minimum(Ps), maximum(Ps), minimum(Ms), maximum(Ms)]
    imshow(angle, aspect="auto", interpolation="nearest", origin="lower", vmin=0, vmax=1, extent=extent);
    colorbar();
    contour(repmat(Ps', length(Ms), 1), repmat(Ms, 1, length(Ps)), valid, 1, linewidths=4, colors="k")
end
figure(figsize=(8, 6))
subplot(221)
plot_helper(angle_cheat)
title("#0"); ylabel("M")
subplot(222)
plot_helper(angle_simple)
title("#1");
subplot(223)
plot_helper(angle_subspace)
title("#2"); ylabel("M"); xlabel("P")
subplot(224)
plot_helper(angle_signal)
title("#3"); xlabel("P")
Out[16]: [figure: overlap between Sw and w_hat for algorithms #0 through #3 over the (M, P) grid, with the detectability boundary in black]
In [19]:
function plot_helper(err)
    extent = [minimum(Ps), maximum(Ps), minimum(Ms), maximum(Ms)]
    imshow(1 - err, aspect="auto", interpolation="nearest", origin="lower", vmin=0, vmax=1, extent=extent);
    colorbar();
    contour(repmat(Ps', length(Ms), 1), repmat(Ms, 1, length(Ps)), valid, 1, linewidths=4, colors="k")
end
figure(figsize=(8, 6))
subplot(221)
plot_helper(err_cheat)
title("#0"); ylabel("M")
subplot(222)
plot_helper(err_simple)
title("#1")
subplot(223)
plot_helper(err_subspace)
title("#2"); ylabel("M"); xlabel("P")
subplot(224)
plot_helper(err_signal)
title("#3"); xlabel("P")
Out[19]: [figure: 1 - normalized held-out error for algorithms #0 through #3 over the (M, P) grid, with the detectability boundary in black]
In [19]:
# Sanity check: do M-row subsamples of an orthonormal basis behave like an i.i.d. Gaussian matrix?
K = 10
f = figure(figsize=(4, 3))
@manipulate for M = 1000:1000:5000, N = 1000:1000:5000
    # note: assumes M <= N; the row indexing fails otherwise
    Uorth = qr(randn(N, K))[1][1:M, :]   # first M rows of an orthonormal N-by-K basis
    U = randn(N, K)[1:M, :] / sqrt(N)    # first M rows of an i.i.d. Gaussian N-by-K matrix
    Sorth = sort(eig(Uorth' * Uorth)[1])
    S = sort(eig(U' * U)[1])
    withfig(f) do
        plot(S, Sorth, ".")
        plot([minimum(S), maximum(S)], [minimum(S), maximum(S)], "k-")   # identity line
    end
end
Out[19]: [interactive figure: sorted eigenvalues of the subsampled orthonormal basis plotted against the Gaussian case, with the identity line]